"""
Phase 5: Robustness battery for the commodity demographics paper.

Tests whether the core finding (Z1 x resource_rents_gdp interaction = +0.154***)
survives six standard robustness checks:
  1. Fixed effects (year FE, two-way FE)
  2. Predetermined demographics (5-year lagged Z1)
  3. OECD vs non-OECD sample split
  4. Cluster bootstrap (500 iterations)
  5. Permutation test (500 iterations)
  6. Leave-one-region-out jackknife
"""

import sys
import numpy as np
import pandas as pd
from pathlib import Path

sys.path.insert(0, '/mnt/c/demographics_capital_flows/multilateral/src')
from model import PanelGLS

# ── Paths ─────────────────────────────────────────────────────────────────
# Shared project root for input data and generated tables.
_PROJECT_ROOT = Path('/mnt/c/demographics_capital_flows')
DATA_PATH = (_PROJECT_ROOT / 'multilateral' / '140_country' / 'data'
             / 'processed' / 'full_panel_with_resources.csv')
OUTPUT_DIR = _PROJECT_ROOT / 'commodity_demographics' / 'output' / 'tables'
# Created at import time; exist_ok makes this idempotent across runs.
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# ── Variable definitions ──────────────────────────────────────────────────
# Name of the key regressor whose coefficient every robustness check reports.
INTERACTION_VAR = 'Z1_x_resource'
# Macro controls shared by all specifications.
CONTROLS = [
    'fiscal_bal_gdp',
    'nfa_gdp_lag',
    'rgdp_growth',
    'log_rel_opw',
    'kaopen',
]
# Z_1..Z_3 terms plus the controls.
BASE_VARS = ['Z_1', 'Z_2', 'Z_3', *CONTROLS]
# Core specification: base regressors, resource rents, and the Z_1 x resource
# interaction (always the last column).
CORE_VARS = [*BASE_VARS, 'resource_rents_gdp', INTERACTION_VAR]

# ── OECD country list ────────────────────────────────────────────────────
# OECD members (38 ISO3 codes, sorted alphabetically). This is a single static
# list: the sample split applies it to every year of the panel.
OECD = {
    'AUS', 'AUT', 'BEL', 'CAN', 'CHE', 'CHL', 'COL', 'CRI', 'CZE', 'DEU',
    'DNK', 'ESP', 'EST', 'FIN', 'FRA', 'GBR', 'GRC', 'HUN', 'IRL', 'ISL',
    'ISR', 'ITA', 'JPN', 'KOR', 'LTU', 'LUX', 'LVA', 'MEX', 'NLD', 'NOR',
    'NZL', 'POL', 'PRT', 'SVK', 'SVN', 'SWE', 'TUR', 'USA',
}

# ── World Bank region mapping ────────────────────────────────────────────
# ISO3 code -> World Bank region label. The seven lists below are disjoint;
# codes not listed here are labelled 'Other' by the region jackknife.
REGION_MAP = {}

_eap = [
    'AUS', 'BRN', 'CHN', 'FJI', 'HKG', 'IDN', 'JPN', 'KHM', 'KOR', 'LAO',
    'MAC', 'MMR', 'MNG', 'MYS', 'NZL', 'PHL', 'PNG', 'SGP', 'SLB', 'THA',
    'TLS', 'TON', 'TWN', 'VNM', 'VUT', 'WSM',
]
# Malta (MLT) is not here: the WB classifies it in Middle East & North Africa.
_eca = [
    'ALB', 'ARM', 'AUT', 'AZE', 'BEL', 'BGR', 'BIH', 'BLR', 'CHE', 'CYP',
    'CZE', 'DEU', 'DNK', 'ESP', 'EST', 'FIN', 'FRA', 'GBR', 'GEO', 'GRC',
    'HRV', 'HUN', 'IRL', 'ISL', 'ITA', 'KAZ', 'KGZ', 'LTU', 'LUX', 'LVA',
    'MDA', 'MKD', 'MNE', 'NLD', 'NOR', 'POL', 'PRT', 'ROU', 'RUS', 'SRB',
    'SVK', 'SVN', 'SWE', 'TJK', 'TKM', 'TUR', 'UKR', 'UZB', 'XKX',
]
_lac = [
    'ABW', 'ARG', 'ATG', 'BHS', 'BLZ', 'BOL', 'BRA', 'BRB', 'CHL', 'COL',
    'CRI', 'CUB', 'DMA', 'DOM', 'ECU', 'GRD', 'GTM', 'GUY', 'HND', 'HTI',
    'JAM', 'KNA', 'LCA', 'MEX', 'NIC', 'PAN', 'PER', 'PRY', 'SLV', 'SUR',
    'TTO', 'URY', 'VCT', 'VEN',
]
_mena = [
    'ARE', 'BHR', 'DJI', 'DZA', 'EGY', 'IRN', 'IRQ', 'ISR', 'JOR', 'KWT',
    'LBN', 'LBY', 'MAR', 'MLT', 'OMN', 'PSE', 'QAT', 'SAU', 'SYR', 'TUN',
    'YEM',
]
_na = ['CAN', 'USA']
_sa = ['AFG', 'BGD', 'BTN', 'IND', 'LKA', 'MDV', 'NPL', 'PAK']
# Sao Tome and Principe (STP) is an African economy and belongs here, not in LAC.
_ssa = [
    'AGO', 'BDI', 'BEN', 'BFA', 'BWA', 'CAF', 'CIV', 'CMR', 'COD', 'COG',
    'COM', 'CPV', 'ERI', 'ETH', 'GAB', 'GHA', 'GIN', 'GMB', 'GNB', 'GNQ',
    'KEN', 'LBR', 'LSO', 'MDG', 'MLI', 'MOZ', 'MRT', 'MUS', 'MWI', 'NAM',
    'NER', 'NGA', 'RWA', 'SDN', 'SEN', 'SLE', 'SOM', 'SSD', 'STP', 'SWZ',
    'SYC', 'TCD', 'TGO', 'TZA', 'UGA', 'ZAF', 'ZMB', 'ZWE',
]

for _codes, _region in [
    (_eap, 'East Asia & Pacific'),
    (_eca, 'Europe & Central Asia'),
    (_lac, 'Latin America & Caribbean'),
    (_mena, 'Middle East & North Africa'),
    (_na, 'North America'),
    (_sa, 'South Asia'),
    (_ssa, 'Sub-Saharan Africa'),
]:
    for _code in _codes:
        # Defensive: should a code ever be listed twice, the first region wins.
        REGION_MAP.setdefault(_code, _region)


# ══════════════════════════════════════════════════════════════════════════
# Helper functions
# ══════════════════════════════════════════════════════════════════════════

def load_data(path=None, max_year=2024):
    """Load the panel, keep years up to ``max_year``, and build the interaction.

    Parameters
    ----------
    path : str, Path or file-like, optional
        CSV source passed to ``pd.read_csv``; defaults to ``DATA_PATH``.
    max_year : int, default 2024
        Rows with ``year > max_year`` are dropped. Rows with a missing year
        are dropped as well, because ``NaN <= max_year`` is False.

    Returns
    -------
    pandas.DataFrame
        Filtered copy with ``Z1_x_resource = Z_1 * resource_rents_gdp``.
    """
    df = pd.read_csv(DATA_PATH if path is None else path)
    df = df[df['year'] <= max_year].copy()
    # Always recompute, so the interaction is consistent with the loaded Z_1 /
    # resource columns even if the CSV ships a stale version of it.
    df['Z1_x_resource'] = df['Z_1'] * df['resource_rents_gdp']
    return df


def prepare(df, vars_needed=None):
    """Return a copy of ``df`` without rows missing the outcome or a regressor.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing ``ca_gdp`` and every name in ``vars_needed``.
    vars_needed : sequence of str, optional
        Regressor columns that must be non-missing; defaults to ``CORE_VARS``.

    Returns
    -------
    pandas.DataFrame
        All original columns, restricted to complete rows (independent copy).
    """
    if vars_needed is None:
        vars_needed = CORE_VARS
    required = ['ca_gdp', *vars_needed]
    return df.dropna(subset=required).copy()


def run_core(df):
    """Fit PanelGLS on the core specification.

    The estimation sample is ``prepare(df)`` (complete cases on ``ca_gdp`` and
    ``CORE_VARS``); entities are ``iso3`` codes and periods are ``year``.

    Returns
    -------
    tuple
        ``(fitted PanelGLS model, estimation sample DataFrame)``.
    """
    sample = prepare(df)
    model = PanelGLS()
    model.fit(
        sample['ca_gdp'].values,
        sample[CORE_VARS].values,
        sample['iso3'].values,
        sample['year'].values,
    )
    return model, sample


def interaction_stats(model, var_list=CORE_VARS):
    """Return ``(beta, se, p)`` for ``INTERACTION_VAR`` from a fitted model.

    ``var_list`` must give the regressor names in the column order used at
    fit time; the interaction is located by name within it and the same
    position is read from ``model.beta``, ``model.se`` and ``model.pvalues``.
    """
    pos = var_list.index(INTERACTION_VAR)
    coef = model.beta[pos]
    std_err = model.se[pos]
    p_value = model.pvalues[pos]
    return coef, std_err, p_value


# ══════════════════════════════════════════════════════════════════════════
# Baseline
# ══════════════════════════════════════════════════════════════════════════

def run_baseline(df):
    """Fit the core specification and package it as the 'Baseline' row.

    Returns
    -------
    tuple
        ``(row, model)``. ``row`` is a dict in the shared results schema, with
        statistics rounded to 4 dp. ``model`` is the fitted PanelGLS instance.
    """
    model, _ = run_core(df)
    beta, se, p = interaction_stats(model)
    print(f"\n[Baseline] N={model.n_obs}, R2={model.r_squared:.4f}, "
          f"Z1xResource={beta:.4f} (SE={se:.4f}, p={p:.4f})")
    row = dict(
        test='Baseline',
        specification='Core model',
        n_obs=model.n_obs,
        r_squared=round(model.r_squared, 4),
        interaction_beta=round(beta, 4),
        interaction_se=round(se, 4),
        interaction_p=round(p, 4),
        notes='',
    )
    return row, model


# ══════════════════════════════════════════════════════════════════════════
# 1. Fixed Effects Robustness
# ══════════════════════════════════════════════════════════════════════════

def twoway_demean(frame, cols, entity_col='iso3', time_col='year',
                  tol=1e-10, max_iter=1000):
    """Sweep entity and time means out of ``frame[cols]``.

    The columns are alternately entity-demeaned and time-demeaned until the
    time means removed in a pass are all below ``tol`` (alternating
    projections). The one-shot formula ``x - x_i. - x_.t + x_..`` is exact
    only when every entity is observed in every period. This iteration
    converges to the exact two-way within transformation on unbalanced
    panels too.

    Parameters
    ----------
    frame : pandas.DataFrame
        Panel containing ``cols``, ``entity_col`` and ``time_col``.
    cols : list of str
        Columns to transform (cast to float).
    entity_col, time_col : str
        Grouping columns for the two sets of effects.
    tol : float
        Absolute tolerance on the largest time mean removed in a pass.
    max_iter : int
        Maximum number of entity+time sweeps.

    Returns
    -------
    pandas.DataFrame
        Transformed ``cols``, aligned to ``frame.index``.
    """
    out = frame[cols].astype(float)
    if out.empty:
        return out
    entities = frame[entity_col].to_numpy()
    periods = frame[time_col].to_numpy()
    for _ in range(max_iter):
        out = out - out.groupby(entities).transform('mean')
        period_means = out.groupby(periods).transform('mean')
        out = out - period_means
        # Once the time sweep barely moves anything, the entity means it could
        # have disturbed are (numerically) zero as well.
        if np.abs(period_means.to_numpy()).max() < tol:
            break
    else:
        print(f"   Warning: two-way demeaning did not converge in {max_iter} sweeps")
    return out


def run_fe_robustness(df, baseline_beta):
    """Re-estimate the core specification with year FE and with two-way FE.

    1a. Year FE: demean every variable by year, then fit PanelGLS.
    1b. Two-way FE: sweep entity and year effects out with ``twoway_demean``
        (exact on the unbalanced panel), then fit OLS with standard errors
        clustered by country. Plain OLS SEs on within-transformed data would
        ignore the degrees of freedom absorbed by the fixed effects and any
        within-country serial correlation.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel as returned by ``load_data``.
    baseline_beta : float
        Baseline interaction coefficient; attenuation is reported relative to it.

    Returns
    -------
    list of dict
        One results row per specification.
    """
    import statsmodels.api as sm

    rows = []
    sub = prepare(df)
    cols = ['ca_gdp'] + CORE_VARS

    # ── 1a. Year FE: demean by year, then PanelGLS (which adds AR1) ──
    print("\n[1a] Year FE ...")
    year_means = sub.groupby('year')[cols].transform('mean')
    sub_yr = sub.copy()
    sub_yr[cols] = sub[cols] - year_means

    gls_yr = PanelGLS()
    gls_yr.fit(sub_yr['ca_gdp'].values, sub_yr[CORE_VARS].values,
               sub_yr['iso3'].values, sub_yr['year'].values)
    beta, se, p = interaction_stats(gls_yr)
    attn = (1 - beta / baseline_beta) * 100 if baseline_beta != 0 else np.nan
    print(f"   N={gls_yr.n_obs}, R2={gls_yr.r_squared:.4f}, "
          f"Z1xResource={beta:.4f}, attenuation={attn:.1f}%")
    rows.append({
        'test': 'Fixed Effects',
        'specification': 'Year FE (demean by year)',
        'n_obs': gls_yr.n_obs,
        'r_squared': round(gls_yr.r_squared, 4),
        'interaction_beta': round(beta, 4),
        'interaction_se': round(se, 4),
        'interaction_p': round(p, 4),
        'notes': f'Attenuation from baseline: {attn:.1f}%',
    })

    # ── 1b. Two-way FE: exact within transform, OLS with clustered SEs ──
    print("[1b] Two-way FE ...")
    within = twoway_demean(sub, cols)
    # has_constant='add' guarantees column 0 is the constant, so the +1 offset
    # below is always right.
    X_tw = sm.add_constant(within[CORE_VARS].to_numpy(), has_constant='add')
    clusters = pd.factorize(sub['iso3'])[0]
    ols = sm.OLS(within['ca_gdp'].to_numpy(), X_tw).fit(
        cov_type='cluster', cov_kwds={'groups': clusters})
    idx = CORE_VARS.index(INTERACTION_VAR) + 1  # +1 for constant
    beta_tw = ols.params[idx]
    se_tw = ols.bse[idx]
    p_tw = ols.pvalues[idx]
    r2_tw = ols.rsquared
    attn_tw = (1 - beta_tw / baseline_beta) * 100 if baseline_beta != 0 else np.nan
    print(f"   N={ols.nobs:.0f}, R2={r2_tw:.4f}, "
          f"Z1xResource={beta_tw:.4f}, attenuation={attn_tw:.1f}%")
    rows.append({
        'test': 'Fixed Effects',
        'specification': 'Two-way FE (entity + year within)',
        'n_obs': int(ols.nobs),
        'r_squared': round(r2_tw, 4),
        'interaction_beta': round(beta_tw, 4),
        'interaction_se': round(se_tw, 4),
        'interaction_p': round(p_tw, 4),
        'notes': f'Attenuation from baseline: {attn_tw:.1f}%; SE clustered by country',
    })

    return rows


# ══════════════════════════════════════════════════════════════════════════
# 2. Predetermined Demographics (5-year lagged Z1)
# ══════════════════════════════════════════════════════════════════════════

def lag_by_year(frame, col, lag, entity_col='iso3', time_col='year'):
    """Return ``col`` as observed ``lag`` calendar years earlier for each entity.

    Values are matched on the key ``(entity, year - lag)`` rather than by
    shifting rows. A country with gaps in its year sequence therefore gets
    NaN instead of a value pulled from the wrong year. If an (entity, year)
    pair appears more than once, the last occurrence is used as the lookup
    value.

    Returns
    -------
    pandas.Series
        Lagged values aligned to ``frame.index`` (NaN where no match exists).
    """
    lookup = (
        frame[[entity_col, time_col, col]]
        .drop_duplicates([entity_col, time_col], keep='last')
        .set_index([entity_col, time_col])[col]
    )
    keys = pd.MultiIndex.from_arrays(
        [frame[entity_col].to_numpy(), frame[time_col].to_numpy() - lag]
    )
    return pd.Series(lookup.reindex(keys).to_numpy(), index=frame.index,
                     name=f'{col}_lag{lag}')


def run_predetermined(df):
    """Re-estimate with Z_1 (and its interaction) taken from 5 years earlier.

    The lag is matched on ``(iso3, year - 5)`` via ``lag_by_year``, so missing
    years cannot misalign it. The interaction uses the lagged Z_1 times
    contemporaneous ``resource_rents_gdp``. Incomplete rows are dropped
    before fitting PanelGLS.

    Returns
    -------
    dict
        One results row in the shared schema.
    """
    print("\n[2] Predetermined demographics (Z1 lagged 5 years) ...")
    sub = df.sort_values(['iso3', 'year']).copy()
    # Assign positionally to avoid any dependence on index uniqueness.
    sub['Z_1_lag5'] = lag_by_year(sub, 'Z_1', 5).to_numpy()
    sub['Z1_lag5_x_resource'] = sub['Z_1_lag5'] * sub['resource_rents_gdp']

    lag_vars = ['Z_1_lag5', 'Z_2', 'Z_3'] + CONTROLS + ['resource_rents_gdp', 'Z1_lag5_x_resource']
    sub = sub.dropna(subset=['ca_gdp'] + lag_vars)

    gls = PanelGLS()
    gls.fit(sub['ca_gdp'].values, sub[lag_vars].values,
            sub['iso3'].values, sub['year'].values)
    idx = lag_vars.index('Z1_lag5_x_resource')
    beta = gls.beta[idx]
    se = gls.se[idx]
    p = gls.pvalues[idx]
    print(f"   N={gls.n_obs}, R2={gls.r_squared:.4f}, "
          f"Z1_lag5 x Resource={beta:.4f} (SE={se:.4f}, p={p:.4f})")
    return {
        'test': 'Predetermined Demographics',
        'specification': 'Z1 lagged 5 years',
        'n_obs': gls.n_obs,
        'r_squared': round(gls.r_squared, 4),
        'interaction_beta': round(beta, 4),
        'interaction_se': round(se, 4),
        'interaction_p': round(p, 4),
        'notes': 'Z1_lag5 x resource_rents_gdp',
    }


# ══════════════════════════════════════════════════════════════════════════
# 3. OECD vs Non-OECD Split
# ══════════════════════════════════════════════════════════════════════════

def run_oecd_split(df):
    """Fit the core model separately on OECD members and non-members.

    Membership comes from the static ``OECD`` set. A subsample with fewer than
    50 complete observations is reported as skipped and contributes no row.

    Returns
    -------
    list of dict
        Zero, one or two results rows.
    """
    member = df['iso3'].isin(OECD)
    results = []
    for label, keep in (('OECD', member), ('Non-OECD', ~member)):
        print(f"\n[3] Sample split: {label} ...")
        sample = prepare(df[keep].copy())
        n_sample = len(sample)
        if n_sample < 50:
            print(f"   Skipped: only {n_sample} obs")
            continue

        model = PanelGLS()
        model.fit(sample['ca_gdp'].values, sample[CORE_VARS].values,
                  sample['iso3'].values, sample['year'].values)
        beta, se, p = interaction_stats(model)
        print(f"   N={model.n_obs}, R2={model.r_squared:.4f}, "
              f"Z1xResource={beta:.4f} (SE={se:.4f}, p={p:.4f})")
        results.append(dict(
            test='Sample Split',
            specification=label,
            n_obs=model.n_obs,
            r_squared=round(model.r_squared, 4),
            interaction_beta=round(beta, 4),
            interaction_se=round(se, 4),
            interaction_p=round(p, 4),
            notes=f'{model.n_countries} countries',
        ))
    return results


# ══════════════════════════════════════════════════════════════════════════
# 4. Cluster Bootstrap (500 iterations)
# ══════════════════════════════════════════════════════════════════════════

def two_sided_normal_p(beta, se):
    """Two-sided p-value for H0: coefficient = 0 against a normal reference.

    Computes ``erfc(|beta / se| / sqrt(2))``, i.e. ``2 * (1 - Phi(|beta / se|))``.
    Returns NaN when ``se`` is not a positive finite number.
    """
    import math

    se = float(se)
    if not (math.isfinite(se) and se > 0):
        return float('nan')
    return math.erfc(abs(float(beta)) / se / math.sqrt(2))


def run_cluster_bootstrap(df, n_iter=500):
    """Country-level (cluster) bootstrap of the interaction coefficient.

    Each iteration draws countries with replacement, keeping all of a drawn
    country's years. Each draw is relabelled so that a country drawn twice
    enters as two entities, and the core model is refitted.

    Reported columns:
    - interaction_beta: the full-sample estimate.
    - interaction_se: the standard deviation of the bootstrap betas.
    - interaction_p: two-sided normal-approximation p-value from beta / boot SE.

    The notes carry the 95% percentile interval and the share of resamples
    with p < 0.05. Resamples that fail to fit are skipped and counted.

    Returns
    -------
    dict
        One results row in the shared schema. SE and p are NaN if fewer than
        two resamples could be fitted.
    """
    print(f"\n[4] Cluster bootstrap ({n_iter} iterations) ...")
    sub = prepare(df)
    countries = sub['iso3'].unique()
    n_countries = len(countries)
    # Split once; each draw then reuses the per-country frames instead of
    # scanning the whole panel for every sampled country.
    by_country = {code: grp for code, grp in sub.groupby('iso3', sort=False)}

    observed_model = PanelGLS()
    observed_model.fit(sub['ca_gdp'].values, sub[CORE_VARS].values,
                       sub['iso3'].values, sub['year'].values)
    obs_beta, obs_se, obs_p = interaction_stats(observed_model)

    # Local generator: same stream as np.random.seed(42) + np.random.choice,
    # without mutating NumPy's global random state.
    rng = np.random.RandomState(42)
    boot_betas = []
    boot_sig = 0
    n_failed = 0

    for i in range(n_iter):
        sampled = rng.choice(countries, size=n_countries, replace=True)
        # Relabel to avoid duplicate entity IDs
        boot_df = pd.concat(
            [by_country[c].assign(iso3_boot=f'{c}_{j}') for j, c in enumerate(sampled)],
            ignore_index=True,
        )

        try:
            gls = PanelGLS()
            gls.fit(boot_df['ca_gdp'].values, boot_df[CORE_VARS].values,
                    boot_df['iso3_boot'].values, boot_df['year'].values)
            beta_i, _, p_i = interaction_stats(gls)
        except Exception:
            # Degenerate resample (failure type depends on PanelGLS): skip, but count it.
            n_failed += 1
        else:
            boot_betas.append(beta_i)
            if p_i < 0.05:
                boot_sig += 1

        if (i + 1) % 100 == 0:
            print(f"   ... {i+1}/{n_iter}")

    n_ok = len(boot_betas)
    if n_failed:
        print(f"   {n_failed} of {n_iter} resamples failed to fit and were skipped")

    row = {
        'test': 'Cluster Bootstrap',
        'specification': f'{n_iter} iterations, country-level',
        'n_obs': observed_model.n_obs,
        'r_squared': round(observed_model.r_squared, 4),
        'interaction_beta': round(obs_beta, 4),
    }
    if n_ok < 2:
        print(f"   Only {n_ok} resamples fitted; bootstrap statistics unavailable")
        row.update(interaction_se=np.nan, interaction_p=np.nan,
                   notes=f'Only {n_ok} of {n_iter} resamples fitted')
        return row

    boot_betas = np.asarray(boot_betas)
    boot_se = np.std(boot_betas, ddof=1)
    ci_lo, ci_hi = np.percentile(boot_betas, [2.5, 97.5])
    frac_sig = boot_sig / n_ok
    boot_p = two_sided_normal_p(obs_beta, boot_se)

    print(f"   Observed beta: {obs_beta:.4f}")
    print(f"   Bootstrap SE: {boot_se:.4f}")
    print(f"   95% CI: [{ci_lo:.4f}, {ci_hi:.4f}]")
    print(f"   Bootstrap p (normal approx.): {boot_p:.4f}")
    print(f"   Fraction p<0.05: {frac_sig:.3f}")

    notes = (f'95% CI: [{ci_lo:.4f}, {ci_hi:.4f}]; {frac_sig:.1%} of iterations p<0.05; '
             f'p from beta / bootstrap SE (normal approx.)')
    if n_failed:
        notes += f'; {n_failed} resamples failed'
    row.update(interaction_se=round(boot_se, 4), interaction_p=round(boot_p, 4),
               notes=notes)
    return row


# ══════════════════════════════════════════════════════════════════════════
# 5. Permutation Test (500 iterations)
# ══════════════════════════════════════════════════════════════════════════

def permutation_p_value(observed, permuted):
    """Two-sided Monte Carlo permutation p-value with the add-one correction.

    The p-value is ``(1 + #{|b*| >= |b_obs|}) / (1 + B)``, where B is the
    number of permuted statistics. The observed statistic counts as one
    member of the permutation distribution, so the result is never exactly
    zero.
    """
    permuted = np.asarray(permuted, dtype=float)
    n_extreme = int(np.sum(np.abs(permuted) >= abs(observed)))
    return (n_extreme + 1) / (permuted.size + 1)


def run_permutation_test(df, n_iter=500):
    """Permutation test: shuffle Z_1 across countries within each year.

    In each iteration, Z_1 is permuted within every year, the interaction is
    recomputed from the shuffled Z_1, and the core model is refitted. All
    other regressors stay fixed. Permutations that fail to fit are skipped
    and counted.

    Returns
    -------
    dict
        One results row. interaction_se holds the SD of the permuted betas.
        interaction_p holds the add-one-corrected permutation p-value (NaN if
        nothing could be fitted).
    """
    print(f"\n[5] Permutation test ({n_iter} iterations) ...")
    sub = prepare(df)

    # Observed coefficient
    observed_model = PanelGLS()
    observed_model.fit(sub['ca_gdp'].values, sub[CORE_VARS].values,
                       sub['iso3'].values, sub['year'].values)
    obs_beta, _, _ = interaction_stats(observed_model)
    print(f"   Observed Z1xResource beta: {obs_beta:.4f}")

    # Only Z_1 and the interaction change between permutations, so build the
    # fixed arrays once and overwrite just those two columns per iteration.
    y = sub['ca_gdp'].to_numpy()
    entities = sub['iso3'].to_numpy()
    years = sub['year'].to_numpy()
    base_X = sub[CORE_VARS].to_numpy(dtype=float)
    z1_col = CORE_VARS.index('Z_1')
    int_col = CORE_VARS.index(INTERACTION_VAR)
    z1_orig = sub['Z_1'].to_numpy(dtype=float)
    resource = sub['resource_rents_gdp'].to_numpy(dtype=float)
    # Row positions of each year, in order of first appearance (the same order
    # in which the shuffles were historically drawn).
    year_rows = [np.flatnonzero(years == yr) for yr in pd.unique(years)]

    # Local generator: same stream as np.random.seed(123) + np.random.shuffle,
    # without mutating NumPy's global random state.
    rng = np.random.RandomState(123)
    perm_betas = []
    n_failed = 0

    for i in range(n_iter):
        z1 = z1_orig.copy()
        for rows in year_rows:
            shuffled = z1[rows]  # fancy indexing returns a copy
            rng.shuffle(shuffled)
            z1[rows] = shuffled
        X = base_X.copy()
        X[:, z1_col] = z1
        X[:, int_col] = z1 * resource

        try:
            gls = PanelGLS()
            # Fresh copies, in case PanelGLS.fit modifies its inputs in place
            # (not verified from here).
            gls.fit(y.copy(), X, entities.copy(), years.copy())
            beta_i, _, _ = interaction_stats(gls)
        except Exception:
            n_failed += 1
        else:
            perm_betas.append(beta_i)

        if (i + 1) % 100 == 0:
            print(f"   ... {i+1}/{n_iter}")

    perm_betas = np.asarray(perm_betas, dtype=float)
    if n_failed:
        print(f"   {n_failed} of {n_iter} permutations failed to fit and were skipped")
    if perm_betas.size == 0:
        print("   No permutation could be fitted")
        return {
            'test': 'Permutation Test',
            'specification': f'{n_iter} iterations, shuffle Z1 within year',
            'n_obs': observed_model.n_obs,
            'r_squared': round(observed_model.r_squared, 4),
            'interaction_beta': round(obs_beta, 4),
            'interaction_se': np.nan,
            'interaction_p': np.nan,
            'notes': f'All {n_iter} permutations failed to fit',
        }

    perm_p = permutation_p_value(obs_beta, perm_betas)
    perm_mean = np.mean(perm_betas)
    perm_sd = np.std(perm_betas)
    print(f"   Permutation p-value: {perm_p:.4f}")
    print(f"   Permutation mean beta: {perm_mean:.4f}")
    print(f"   Permutation SD: {perm_sd:.4f}")

    return {
        'test': 'Permutation Test',
        'specification': f'{n_iter} iterations, shuffle Z1 within year',
        'n_obs': observed_model.n_obs,
        'r_squared': round(observed_model.r_squared, 4),
        'interaction_beta': round(obs_beta, 4),
        'interaction_se': round(perm_sd, 4),
        'interaction_p': round(perm_p, 4),
        'notes': (f'Permutation p = {perm_p:.4f} (add-one corrected); '
                  f'mean shuffled beta = {perm_mean:.4f}'),
    }


# ══════════════════════════════════════════════════════════════════════════
# 6. Leave-One-Region-Out Jackknife
# ══════════════════════════════════════════════════════════════════════════

def run_region_jackknife(df):
    """Leave-one-region-out jackknife of the core model.

    Countries are mapped to regions through ``REGION_MAP``; unmapped codes are
    grouped as 'Other', which is jackknifed like any other region. Regions
    are processed in sorted order. For each, the core model is refitted
    without that region's countries, unless fewer than 50 rows would remain,
    in which case that region is silently skipped.

    Returns
    -------
    list of dict
        One results row per region actually dropped.
    """
    print("\n[6] Leave-one-region-out jackknife ...")
    panel = prepare(df)
    panel['region'] = panel['iso3'].map(REGION_MAP).fillna('Other')

    results = []
    for dropped_region in sorted(panel['region'].unique()):
        in_region = panel['region'] == dropped_region
        kept = panel[~in_region].copy()
        if len(kept) < 50:
            continue

        model = PanelGLS()
        model.fit(kept['ca_gdp'].values, kept[CORE_VARS].values,
                  kept['iso3'].values, kept['year'].values)
        beta, se, p = interaction_stats(model)
        n_dropped = panel.loc[in_region, 'iso3'].nunique()
        print(f"   Drop {dropped_region:<30s} => beta={beta:.4f}, SE={se:.4f}, "
              f"p={p:.4f}, N={model.n_obs} (dropped {n_dropped} countries)")
        results.append(dict(
            test='Region Jackknife',
            specification=f'Drop {dropped_region}',
            n_obs=model.n_obs,
            r_squared=round(model.r_squared, 4),
            interaction_beta=round(beta, 4),
            interaction_se=round(se, 4),
            interaction_p=round(p, 4),
            notes=f'Dropped {n_dropped} countries',
        ))
    return results


# ══════════════════════════════════════════════════════════════════════════
# Main
# ══════════════════════════════════════════════════════════════════════════

def main():
    """Run the full robustness battery and write ``table_robustness.csv``.

    Each check returns one row (or a list of rows) in a shared schema. The
    rows are concatenated in battery order, saved under ``OUTPUT_DIR`` and
    printed as a summary table.
    """
    print("=" * 70)
    print("COMMODITY DEMOGRAPHICS — ROBUSTNESS BATTERY")
    print("=" * 70)

    df = load_data()
    all_rows = []

    # Baseline
    baseline_row, baseline_model = run_baseline(df)
    all_rows.append(baseline_row)
    # Attenuation is measured against the full-precision estimate;
    # baseline_row['interaction_beta'] is rounded to 4 dp for display only.
    baseline_beta, _, _ = interaction_stats(baseline_model)

    # 1. Fixed effects
    all_rows.extend(run_fe_robustness(df, baseline_beta))

    # 2. Predetermined demographics
    all_rows.append(run_predetermined(df))

    # 3. OECD / non-OECD split
    all_rows.extend(run_oecd_split(df))

    # 4. Cluster bootstrap
    all_rows.append(run_cluster_bootstrap(df, n_iter=500))

    # 5. Permutation test
    all_rows.append(run_permutation_test(df, n_iter=500))

    # 6. Region jackknife
    all_rows.extend(run_region_jackknife(df))

    # ── Save ──────────────────────────────────────────────────────────────
    results = pd.DataFrame(all_rows)
    out_path = OUTPUT_DIR / 'table_robustness.csv'
    results.to_csv(out_path, index=False)
    print(f"\nSaved: {out_path}")

    # ── Summary ───────────────────────────────────────────────────────────
    print("\n" + "=" * 70)
    print("ROBUSTNESS SUMMARY")
    print("=" * 70)
    # Scope the display options to this print instead of changing them globally.
    with pd.option_context('display.width', 140, 'display.max_colwidth', 60):
        print(results.to_string(index=False))


if __name__ == '__main__':
    main()
